import os
import sys
import gzip
import pickle
from collections import defaultdict, Counter
import pysam
import pybedtools
from pylab import *


# Half-width of the window of positions that is profiled and plotted;
# profiles have 2*halfwindow+1 entries.
halfwindow = 2000
# Root directory; BAM files are read from <directory>/<dataset>/Mapping/.
directory = "/osc-fs_home/mdehoon/Data/CASPARs/"

# Figure layout, passed to subplots_adjust.
bottom = 0.21
top = 0.72
left = 0.10
right = 0.95
wspace = 0.40

# Values of the XT tag that are accepted when calculating profiles.
keep_targets = {
    'mRNA', 'lncRNA', 'gencode', 'fantomcat', 'novel',
    'genome', 'MALAT1', 'TERC', 'RMRP', 'RPPH',
}

# Values of the XT tag for which the alignment is skipped.
skip_targets = {
    'chrM', 'rRNA', 'tRNA', 'snRNA', 'scRNA', 'snoRNA', 'snar',
    'yRNA', 'scaRNA', 'vRNA', 'histone', 'snhg',
}

# Values of the XA tag for which the alignment is skipped.
skip_categories = {
    "last_exon",
    'presnoRNA', 'prescaRNA', 'presnRNA', 'pretRNA',
    "FANTOM5_enhancer", "roadmap_enhancer", "roadmap_dyadic",
    "novel_enhancer_CAGE", "novel_enhancer_HiSeq",
}

# Values of the XA tag, grouped by how the alignment is counted.
proximal_forward_categories = {"sense_proximal", "sense_upstream"}
distal_forward_categories = {"sense_distal", "sense_distal_upstream"}
reverse_categories = {
    "antisense", "prompt",
    "antisense_distal", "antisense_distal_upstream",
}

# Map the 'ID' attribute of each record in the annotation file to the
# integer value of its 'TSS' attribute.
tsss = {}
for feature in pybedtools.BedTool("genes.FANTOM_CAT.THP-1.gff"):
    attributes = feature.attrs
    gene_id = attributes['ID']
    tsss[gene_id] = int(attributes['TSS'])


def parse_alignments(dataset, library, side="5"):
    """Iterate over the alignments in the BAM file of one library.

    The file is read from ``<directory>/<dataset>/Mapping/<library>.bam``,
    where ``directory`` is the module-level root directory.

    For the "StartSeq", "HiSeq" and "CAGE" datasets every record in the file
    is yielded unchanged and `side` is ignored.  For the "MiSeq" dataset the
    file is consumed two records at a time (read 1 followed by read 2):

    - side "5": yields the read-1 alignment of each pair;
    - side "3": yields ``(alignment1, position)`` tuples, where `position`
      is None if read 1 is unmapped, the ``reference_start`` of read 2 if
      read 1 is on the reverse strand, and ``reference_end - 1`` of read 2
      otherwise.

    This is a generator, so all errors are raised on first iteration.

    Raises:
        Exception: if `dataset` is not one of the known datasets.
        ValueError: if `dataset` is "MiSeq" and `side` is neither "5" nor
            "3", or if the file ends with a record that has no mate.
        AssertionError: if a pair is not ordered as read 1, read 2.
    """
    single_read_datasets = ("StartSeq", "HiSeq", "CAGE")
    # Validate the arguments before touching the file system, so that a typo
    # is reported as such instead of as a missing file or an empty iterator.
    if dataset != "MiSeq" and dataset not in single_read_datasets:
        raise Exception("Unknown dataset %s" % dataset)
    if dataset == "MiSeq" and side not in ("5", "3"):
        raise ValueError("Unknown side %r (expected '5' or '3')" % (side,))
    filename = "%s.bam" % library
    subdirectory = os.path.join(directory, dataset, "Mapping")
    path = os.path.join(subdirectory, filename)
    print("Reading", path)
    # The context manager closes the file when the generator is exhausted,
    # closed, or finalized, instead of leaking the handle.
    with pysam.AlignmentFile(path) as alignments:
        if dataset in single_read_datasets:
            yield from alignments
            return
        for alignment1 in alignments:
            # A bare next() would raise StopIteration inside this generator,
            # which Python reports as an opaque RuntimeError (PEP 479).
            alignment2 = next(alignments, None)
            if alignment2 is None:
                raise ValueError("%s ends with a read that has no mate record"
                                 % path)
            assert alignment1.is_read1
            assert alignment2.is_read2
            if side == "5":
                yield alignment1
                continue
            # side == "3": report a position taken from the mate.
            # NOTE(review): read 2 is not checked for being unmapped here;
            # confirm that the input guarantees both mates are mapped.
            if alignment1.is_unmapped:
                position = None
            elif alignment1.is_reverse:
                position = alignment2.reference_start
            else:
                position = alignment2.reference_end - 1
            yield alignment1, position


def calculate_profile(dataset, library):
    """Calculate the normalized positional profile of one library.

    Mapped alignments are read with parse_alignments(dataset, library, "5").
    Alignments are skipped if their XT tag is in skip_targets, if they have
    no XA tag, or if their XA tag is in skip_categories.  Each remaining
    alignment gets a weight of 1 / NH.

    - Proximal forward categories: the weight is added to the proximal total
      and then divided equally over the comma-separated genes in the XG tag;
      each gene's share is added to that gene's profile at index
      XD + halfwindow.
    - Distal forward categories: the weight is added to the distal total.
    - Reverse categories: the alignment is ignored.

    Each gene's profile is scaled to sum to one, the scaled profiles are
    added up, and the result is scaled to sum to one again.

    Returns:
        A tuple (profile, proximal, distal): the array of length
        2*halfwindow+1 and the two weight totals.

    Raises:
        Exception: if the XA tag is not in any of the known category sets.
    """
    width = 2 * halfwindow + 1
    per_gene = defaultdict(lambda: zeros(width))
    proximal = 0
    distal = 0
    for read in parse_alignments(dataset, library, "5"):
        if read.is_unmapped:
            continue
        try:
            target = read.get_tag("XT")
        except KeyError:
            target = None
        if target in skip_targets:
            continue
        assert target in keep_targets
        try:
            category = read.get_tag("XA")
        except KeyError:
            continue
        if category in skip_categories:
            continue
        # Both tags are read for every remaining alignment, whatever its
        # category.
        weight = 1.0 / read.get_tag("NH")
        distance = read.get_tag("XD")
        if category in reverse_categories:
            continue
        if category in distal_forward_categories:
            assert abs(distance) >= halfwindow
            distal += weight
            continue
        if category not in proximal_forward_categories:
            raise Exception("Unknown category %s" % category)
        assert abs(distance) <= halfwindow
        # The proximal total counts the full weight, before it is split
        # among the genes.
        proximal += weight
        gene_names = read.get_tag("XG").split(",")
        share = weight / len(gene_names)
        column = distance + halfwindow
        for gene_name in gene_names:
            per_gene[gene_name][column] += share
    combined = zeros(width)
    for gene_profile in per_gene.values():
        # Give every gene the same total contribution.
        combined += gene_profile / sum(gene_profile)
    combined /= sum(combined)
    return combined, proximal, distal


# Calculate the profile and the proximal/distal totals of every library in
# every dataset; results are stored as <dict>[dataset][library].
datasets = ("StartSeq", "MiSeq", "HiSeq", "CAGE")
profiles = {}
proximal = {}
distal = {}
for dataset in datasets:
    subdirectory = os.path.join(directory, dataset, "Mapping")
    filenames = os.listdir(subdirectory)
    filenames.sort()
    profiles[dataset] = {}
    proximal[dataset] = {}
    distal[dataset] = {}
    for filename in filenames:
        library, extension = os.path.splitext(filename)
        assert extension == ".bam"
        # The two names must be separate tuple elements: without the comma
        # the adjacent literals are concatenated into one string and `in`
        # becomes a substring test.
        if dataset == "StartSeq" and library in ("SRR7071454", "SRR7071455"):
            # RNA input control libraries
            continue
        elif dataset == "MiSeq" and not library.startswith("t"):
            # include time course samples only
            continue
        elif dataset == "HiSeq" and library == "t01_r3":
            # negative control library prepared using water instead of RNA
            continue
        profiles[dataset][library], proximal[dataset][library], distal[dataset][library] = calculate_profile(dataset, library)

# Positions covered by a profile, from -halfwindow to +halfwindow inclusive.
positions = arange(-halfwindow, +halfwindow+1)

f = figure(figsize=(6.4,2.4))

# Invisible axes spanning the whole figure; they only carry the axis labels
# shared by the panels added later.
ax = f.add_subplot(111)
for edge in ('top', 'bottom', 'left', 'right'):
    ax.spines[edge].set_color('none')
ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)
ax.set_ylabel("Fraction of reads", fontsize=8)
ax.set_xlabel("Position of the 5' end with respect to the transcription start site [nucleotides]", fontsize=8)


# Title text for each dataset: (RNA description, library description).
panel_titles = {
    "StartSeq": ("Very short capped RNAs", "Start-Seq"),
    "MiSeq": ("Short capped RNAs", "paired-end libraries"),
    "HiSeq": ("Short capped RNAs", "single-end libraries"),
    "CAGE": ("Long capped RNAs", "CAGE libraries"),
}

xm = 2000
ym = 0
bins = 101
# One panel per dataset, showing the binned profile averaged over the
# libraries of that dataset.
for i, dataset in enumerate(datasets, start=1):
    f.add_subplot(1, 4, i)
    yam = zeros(bins)
    proximal_pam = 0
    distal_pam = 0
    for library, weights in profiles[dataset].items():
        counts, edges = histogram(positions, bins=bins, weights=weights)
        # Bin centers.
        x = (edges[:-1] + edges[1:]) / 2.0
        yam += counts
        proximal_pam += proximal[dataset][library]
        distal_pam += distal[dataset][library]
    yam /= len(profiles[dataset])
    plot(x, yam, color='black')
    xticks(fontsize=8)
    yticks(fontsize=8)
    rna, name = panel_titles[dataset]
    title("%s,\n%s;\nproximal: $N = %d$;\ndistal: $N = %d$" % (rna, name, proximal_pam, distal_pam), fontsize=8)
    # Track the largest upper y-limit over all panels.
    ymin, ymax = ylim()
    if ymax > ym:
        ym = ymax

# f.axes[0] is the invisible label-carrying axes; the panels follow it.
for ax in f.axes[1:5]:
    ax.set_xlim(-xm, xm)
    # ax.set_ylim(0, ym)


subplots_adjust(bottom=bottom, top=top, left=left, right=right, wspace=wspace)

# Save the figure in both vector and bitmap form.
for filename in ("figure_promoter_distribution_timecourse.svg",
                 "figure_promoter_distribution_timecourse.png"):
    print("Saving figure as %s" % filename)
    savefig(filename)
